import argparse
import os
import shlex
import time

# TODO: point TORCH_HOME and data_paths at your local copies before running.

# Directory that contains the NAS-Bench-201 database file. It can be left
# empty when only the NASNET (i.e. DARTS) search space is used.
TORCH_HOME = "/home/chenxinyu/Codes/NAS/ETENAS-master/NAS_data/"

# Dataset root directories, shared by the entries of data_paths below.
_CIFAR_ROOT = "/home/chenxinyu/Codes/NAS/ETENAS-master/NAS_data/cifar.python"
_IMAGENET16_ROOT = "/home/chenxinyu/data/ImageNet-16-120"

# Maps every accepted --dataset value to the directory given to the search
# script as --data_path.
# NOTE(review): "imagenet-1k" currently reuses the ImageNet-16-120 directory;
# confirm this is intended before running on full ImageNet.
data_paths = {
    "cifar10": _CIFAR_ROOT,
    "cifar100": _CIFAR_ROOT,
    "ImageNet16-120": _IMAGENET16_ROOT,
    "imagenet-1k": _IMAGENET16_ROOT,
}


# Command-line interface of the launcher. The accepted values for the two
# enumerated options are listed once here and handed to argparse as choices.
_SPACE_CHOICES = ['nas-bench-201', 'darts']
_DATASET_CHOICES = ['cifar10', 'cifar100', 'ImageNet16-120', 'imagenet-1k']

parser = argparse.ArgumentParser(prog="TENAS_launch")
# NOTE(review): --k is parsed but not forwarded to the launched command in the
# visible part of this file; confirm whether it is still needed.
parser.add_argument('--k', default=8, type=int)
parser.add_argument(
    '--gpu',
    type=int,
    default=1,
    help='use gpu with cuda number',
)
parser.add_argument(
    '--space',
    type=str,
    default='nas-bench-201',
    choices=_SPACE_CHOICES,
    help='which nas search space to use',
)
parser.add_argument(
    '--dataset',
    type=str,
    default='cifar100',
    choices=_DATASET_CHOICES,
    help='Choose between cifar10/100/ImageNet16-120/imagenet-1k',
)
parser.add_argument(
    '--seed',
    type=int,
    default=0,
    help='manual seed',
)
# Parsed immediately at import time: this file is a script, not a library.
args = parser.parse_args()


##### Basic Settings
# Forwarded unchanged to the search script as --precision.
precision = 3
# Forwarded unchanged to the search script as --init.
# Alternatives noted here previously: 'normal', 'kaiming_uniform'.
init = "kaiming_normal"


# Per-search-space launch settings, all forwarded to the search script:
#   space        - set of operator candidates (not the structure of the supernet)
#   super_type   - type of supernet structure
#   prune_number - value passed as --prune_number
#   batch_size   - value passed as --batch_size
if args.space == "nas-bench-201":
    space = "nas-bench-201"
    super_type = "basic"
    prune_number = 1
    # batch_size = 72
    batch_size = 64
elif args.space == "darts":
    space = "darts"
    super_type = "nasnet-super"
    if args.dataset in ("cifar10", "cifar100"):
        prune_number = 3
        # batch_size = 14
        batch_size = 64
    elif args.dataset == "imagenet-1k":
        prune_number = 2
        # batch_size = 24
        batch_size = 64
    else:
        # No settings exist for this combination (e.g. ImageNet16-120). Without
        # this guard prune_number/batch_size would stay unbound and the script
        # would fail later with a NameError while formatting the command.
        # parser.error prints the usage text plus this message and exits with
        # status 2.
        parser.error(
            "dataset '{}' is not supported with --space darts "
            "(choose cifar10, cifar100 or imagenet-1k)".format(args.dataset)
        )


# UTC launch time used to tag this run, formatted like "Jan-05-2024_13-07-42".
# Only portable strftime directives are used. The previous format relied on
# platform extensions: '%h' (alias of '%b'), '%C' (the century, e.g. "20",
# not the year) and '%s' (epoch seconds, which the C library derives by
# treating the UTC struct as local time) instead of the seconds field '%S'.
timestamp = time.strftime('%b-%d-%Y_%H-%M-%S', time.gmtime())


# Shell command that launches the search script on the selected GPU.
# Path-like and free-form values go through shlex.quote so that whitespace or
# shell metacharacters in a configured path cannot split or alter the command;
# values made only of shell-safe characters are left unchanged by the quoting.
# Every fragment ends with a space, so the assembled text (including its
# trailing space) matches the single-line command this script has always built.
core_cmd = (
    "CUDA_VISIBLE_DEVICES={gpuid} OMP_NUM_THREADS=4 python ./prune_ntknas_1.py "
    "--save_dir {save_dir} --max_nodes {max_nodes} "
    "--dataset {dataset} "
    "--data_path {data_path} "
    "--search_space_name {space} "
    "--super_type {super_type} "
    "--arch_nas_dataset {arch_nas_dataset} "
    "--track_running_stats 1 "
    "--workers 0 --rand_seed {seed} "
    "--timestamp {timestamp} "
    "--precision {precision} "
    "--init {init} "
    "--repeat 3 "
    "--batch_size {batch_size} "
    "--prune_number {prune_number} "
).format(
    gpuid=args.gpu,
    save_dir=shlex.quote(
        "./a/ntk_1/prune-{space}/{dataset}".format(space=space, dataset=args.dataset)
    ),
    max_nodes=4,
    dataset=args.dataset,
    data_path=shlex.quote(data_paths[args.dataset]),
    space=space,
    super_type=super_type,
    # Plain concatenation (not os.path.join) keeps the exact path text the
    # search script received before, including the case of an empty TORCH_HOME.
    arch_nas_dataset=shlex.quote(TORCH_HOME + "/NAS-Bench-201-v1_1-096897.pth"),
    seed=args.seed,
    timestamp=shlex.quote(timestamp),
    precision=precision,
    init=init,
    batch_size=batch_size,
    prune_number=prune_number,
)

# os.system returns the status of the shell command. Propagate a failure
# instead of discarding it, so the launcher no longer exits with status 0
# when the search itself failed. SystemExit with a string prints the message
# to stderr and exits with status 1.
status = os.system(core_cmd)
if status != 0:
    raise SystemExit(
        "search command failed (os.system status {}): {}".format(status, core_cmd)
    )
